Add RocBert #20013
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks a lot for adding this new model! This is very clean and nice use of our Copied from mechanism!
Most of my comments are around the name: it seems the paper names the model RoCBert so let's respect the casing :-)
change RocBert -> RoCBert
@sgugger Thanks for your suggestion, I already fixed it ~
@ArthurZucker hi, I already fixed the code according to sgugger's advice. Could you please review it? Thanks!
Yes! Doing this asap 🤗 sorry for the delay
Hey! Really nice model! Great work, it's clean and interesting!
A few comments here and there; make sure the doctests pass, and I would love to see a more detailed generation test to make sure that the generate function works properly in an integration test.
ps: really loved the use of copied from, thanks for your hard work 😄
device = labels_input_ids.device
target_inputs = torch.clone(labels_input_ids)
target_inputs[target_inputs == -100] = 0
Since we have a config.pad_token_id, let's use it (unless it is a different padding token).
Not sure if targets use this? Usually we rely on the -100 index since it's ignored by PyTorch loss functions.
In the RoCBertForPreTraining model, when computing the sim_matrix between (labels_input_ids, attack_ids), we turn -100 into config.pad_token_id so we can get its pooled_embed from roc_bert.
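As a side note for readers following the thread, here is a minimal sketch of that label-cleaning step; the tensor values are made up, and pad_token_id = 0 is an assumption standing in for config.pad_token_id:

import torch

# Positions masked with -100 for the loss are mapped to the padding id before
# the ids are embedded, since -100 is not a valid vocabulary index.
pad_token_id = 0  # assumption: the real model reads this from config.pad_token_id
labels_input_ids = torch.tensor([[101, 2769, -100, -100, 102]])

target_inputs = torch.clone(labels_input_ids)
target_inputs[target_inputs == -100] = pad_token_id  # now safe to embed
print(target_inputs)  # tensor([[ 101, 2769,    0,    0,  102]])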
with open(word_pronunciation_file, "r", encoding="utf8") as in_file:
    self.word_pronunciation = json.load(in_file)

self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
Most probably a nit, cc @sgugger, not sure how we feel about the collections dependency as we usually do this with native Python.
collections is in the standard lib, so not a problem for me.
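For reference, both spellings build the same id-to-token mapping; a quick sketch with a toy vocab (illustrative only, not the PR's code):

import collections

vocab = {"[PAD]": 0, "[UNK]": 1, "巴": 2, "黎": 3}  # toy vocab, made up here

# The PR's version, using the standard-library collections module:
ids_to_tokens = collections.OrderedDict((ids, tok) for tok, ids in vocab.items())

# Equivalent with a plain dict, since Python 3.7+ preserves insertion order:
ids_to_tokens_plain = {ids: tok for tok, ids in vocab.items()}

assert list(ids_to_tokens.items()) == list(ids_to_tokens_plain.items())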
expected_slice = torch.tensor([[[0.6248, 0.3013, 0.3739], [0.3544, 0.8086, 0.2427], [0.3244, 0.6589, 0.1711]]])

self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
Would be nice to have a few more tests. At least one with an expected Chinese generation text / showing attack resistance. (Not sure if it makes sense, tell me if it doesn't.)
I changed this test code; the input text is now "ba 里 系 [MASK] 国 的 首 都", which is an adversarial version of "巴 黎 是 [MASK] 国 的 首 都" ("Paris is the capital of [MASK]" in English).
We expect the model to learn:
"ba里" => "巴黎" (Paris),
"系" => "是" (is),
and to infer the masked word: "[MASK] 国" => "法国" (France).
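A hedged sketch of such a fill-mask integration check; the checkpoint name comes from this PR, but the rest is illustrative, assuming the model and tokenizer load through the Auto classes, and is not the merged test code:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
model = AutoModelForMaskedLM.from_pretrained("weiweishi/roc-bert-base-zh")

# adversarial spelling of "巴 黎 是 [MASK] 国 的 首 都"
inputs = tokenizer("ba 里 系 [MASK] 国 的 首 都", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# locate the single [MASK] position and decode the top prediction
mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero().item()
predicted_id = logits[0, mask_pos].argmax().item()
print(tokenizer.decode([predicted_id]))  # expected: "法", so "[MASK] 国" reads 法国 (France)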
Last comment: it seems that the issue with naming still persists; we should make sure to either write …
add doc, add detail test
@ArthurZucker I didn't make weiweishi/roc-bert-base-zh public before; it's available now, and the other issues are resolved ~
Thanks again!
* add roc_bert
* update roc_bert readme
* code style
* change name and delete unuse file
* udpate model file
* delete unuse log file
* delete tokenizer fast
* reformat code and change model file path
* add RocBertForPreTraining
* update docs
* delete wrong notes
* fix copies
* fix make repo-consistency error
* fix files are not present in the table of contents error
* change RocBert -> RoCBert
* add doc, add detail test

Co-authored-by: weiweishi <weiweishi@tencent.com>
This PR adds the RoCBert model.
RoCBert is a pre-trained Chinese language model designed from the ground up to be robust against maliciously crafted adversarial texts such as misspellings, homograph attacks, and other forms of deception.
This property is crucial in downstream applications like content moderation.
RoCBert differs from the classic BERT architecture in the following ways:
* besides the usual token embeddings, it consumes phonetic (pronunciation) and visual (glyph shape) features for each character, so the tokenizer emits extra input_shape_ids and input_pronunciation_ids;
* it is pretrained with a contrastive objective that pulls together the representations of a sentence and its adversarially attacked variants.
Since the model structure and tokenizer are quite different from existing implementations, we would like to submit this PR to add a new model class (a tokenizer sketch follows below).
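A short sketch of what that difference looks like at the API level; the class and key names follow this PR, so treat the snippet as illustrative rather than final:

from transformers import RoCBertTokenizer

tokenizer = RoCBertTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
encoding = tokenizer("巴黎是法国的首都", return_tensors="pt")

# Besides the usual input_ids / attention_mask / token_type_ids, the encoding
# carries glyph-shape and pronunciation channels for each character:
print(encoding["input_shape_ids"].shape)
print(encoding["input_pronunciation_ids"].shape)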
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.